import torch
import numpy as np


class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs each input sample with its label.

    Both ``input_data`` and ``label`` are stored as given (no copy or
    conversion). Each must support ``len()`` and indexing, e.g. lists,
    NumPy arrays or tensors.
    """

    def __init__(self, input_data, label):
        """Store the samples and labels.

        Args:
            input_data: Indexable collection of input samples.
            label: Indexable collection of labels, aligned with ``input_data``.

        Raises:
            ValueError: If ``input_data`` and ``label`` differ in length.
        """
        # Use an explicit raise rather than ``assert``: asserts are removed
        # under ``python -O``, which would let misaligned data through.
        if len(input_data) != len(label):
            raise ValueError(
                "input_data and label must have the same length, "
                f"got {len(input_data)} and {len(label)}"
            )
        self.data = input_data
        self.label = label

    def __getitem__(self, item):
        """Return the ``(sample, label)`` pair at ``item``.

        ``item`` is passed unchanged to both underlying containers, so it
        accepts whatever index types they support.
        """
        return self.data[item], self.label[item]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
